"""
Stdio MCP Client - adapted from Letta with working implementation
"""

import asyncio
import sys
from contextlib import asynccontextmanager
from typing import Optional

import anyio
import anyio.lowlevel
import mcp.types as types
from anyio.streams.text import TextReceiveStream
from mcp import ClientSession, StdioServerParameters
from mcp.client.stdio import get_default_environment

from .base_client import BaseMCPClient, BaseAsyncMCPClient
from .types import StdioServerConfig

# Module-level logger named after this module. The logging module is obtained
# through __import__ here instead of a top-of-file import statement.
logger = __import__('logging').getLogger(__name__)


class StdioMCPClient(BaseMCPClient):
    """Synchronous MCP client that reaches its server over a stdio transport.

    Each asynchronous setup step is driven to completion on ``self.loop``.
    """

    def _initialize_connection(self, server_config: StdioServerConfig, timeout: float) -> bool:
        """Open the stdio transport and a client session on top of it.

        Args:
            server_config: Provides ``command``, ``args`` and ``env`` for the
                server process. A falsy ``env`` is replaced by
                ``get_default_environment()``.
            timeout: Seconds allowed for each of the two setup steps (entering
                the transport, then entering the session).

        Returns:
            ``True`` once both the transport and the session have been
            entered; ``False`` if a step timed out or raised. Failures are
            logged rather than propagated.
        """
        try:
            params = StdioServerParameters(
                command=server_config.command,
                args=server_config.args,
                env=server_config.env or get_default_environment(),
            )

            # Step 1: enter the transport context and keep its two streams.
            transport_ctx = forked_stdio_client(params)
            self.stdio, self.write = self.loop.run_until_complete(
                asyncio.wait_for(transport_ctx.__aenter__(), timeout=timeout)
            )

            def _close_transport():
                # Registered for later teardown; leaves the transport context.
                return self.loop.run_until_complete(
                    transport_ctx.__aexit__(None, None, None)
                )

            self.cleanup_funcs.append(_close_transport)

            # Step 2: enter the client session bound to those streams.
            session_ctx = ClientSession(self.stdio, self.write)
            self.session = self.loop.run_until_complete(
                asyncio.wait_for(session_ctx.__aenter__(), timeout=timeout)
            )

            def _close_session():
                # Registered for later teardown; leaves the session context.
                return self.loop.run_until_complete(
                    session_ctx.__aexit__(None, None, None)
                )

            self.cleanup_funcs.append(_close_session)
        except asyncio.TimeoutError:
            logger.error(f"Timed out while establishing stdio connection (timeout={timeout}s).")
            return False
        except Exception as e:
            logger.exception(f"Exception occurred while initializing stdio client session: {e}")
            return False

        return True


class AsyncStdioMCPClient(BaseAsyncMCPClient):
    """Asynchronous MCP client that reaches its server over a stdio transport."""

    async def _initialize_connection(self, server_config: StdioServerConfig, timeout: float) -> bool:
        """Open the stdio transport and a client session on top of it.

        Args:
            server_config: Provides ``command``, ``args`` and ``env`` for the
                server process. A falsy ``env`` is replaced by
                ``get_default_environment()``.
            timeout: Accepted as part of the signature; this implementation
                does not apply it to any of the awaits below.

        Returns:
            ``True`` once both the transport and the session have been
            entered; ``False`` if any step raised. The exception is logged
            rather than propagated.
        """
        try:
            params = StdioServerParameters(
                command=server_config.command,
                args=server_config.args,
                env=server_config.env or get_default_environment(),
            )

            # Step 1: enter the transport context and keep its two streams.
            transport_ctx = forked_stdio_client(params)
            self.stdio, self.write = await transport_ctx.__aenter__()

            def _close_transport():
                # Returns the un-awaited exit coroutine of the transport context.
                return transport_ctx.__aexit__(None, None, None)

            self.cleanup_funcs.append(_close_transport)

            # Step 2: enter the client session bound to those streams.
            session_ctx = ClientSession(self.stdio, self.write)
            self.session = await session_ctx.__aenter__()

            def _close_session():
                # Returns the un-awaited exit coroutine of the session context.
                return session_ctx.__aexit__(None, None, None)

            self.cleanup_funcs.append(_close_session)
        except Exception as e:
            logger.exception(f"Exception occurred while initializing async stdio client session: {e}")
            return False

        return True


@asynccontextmanager
async def forked_stdio_client(server: StdioServerParameters):
    """
    Client transport for stdio: this will connect to a server by spawning a
    process and communicating with it over stdin/stdout.

    Args:
        server: Supplies ``command``, ``args``, ``env``, ``encoding`` and
            ``encoding_error_handler`` for the child process. A falsy ``env``
            is replaced by ``get_default_environment()``. The child's stderr
            is connected to this process's ``sys.stderr``.

    Yields:
        ``(read_stream, write_stream)``: ``read_stream`` delivers each stdout
        line parsed as a ``JSONRPCMessage`` (or the exception raised while
        parsing that line); messages sent on ``write_stream`` are serialized
        to JSON and written to the child's stdin, one per line.

    Raises:
        RuntimeError: If the process cannot be spawned, or (via the task
            group) if the process exits with a non-zero return code.
    """
    read_stream_writer, read_stream = anyio.create_memory_object_stream(0)
    write_stream, write_stream_reader = anyio.create_memory_object_stream(0)

    try:
        process = await anyio.open_process(
            [server.command, *server.args],
            env=server.env or get_default_environment(),
            stderr=sys.stderr,
        )
    except OSError as exc:
        # No task will ever take ownership of the streams, so close all four
        # ends here instead of leaving them open.
        for stream in (read_stream_writer, read_stream, write_stream, write_stream_reader):
            stream.close()
        raise RuntimeError(f"Failed to spawn process: {server.command} {server.args}") from exc

    async def stdout_reader():
        """Read from process stdout and parse JSON-RPC messages"""
        assert process.stdout, "Opened process is missing stdout"
        buffer = ""
        try:
            async with read_stream_writer:
                async for chunk in TextReceiveStream(
                    process.stdout,
                    encoding=server.encoding,
                    errors=server.encoding_error_handler,
                ):
                    # Chunks do not align with line boundaries: keep the
                    # trailing partial line for the next chunk.
                    lines = (buffer + chunk).split("\n")
                    buffer = lines.pop()
                    for line in lines:
                        try:
                            message = types.JSONRPCMessage.model_validate_json(line)
                        except Exception as exc:
                            # Forward the parse failure to the consumer
                            # instead of killing the reader.
                            await read_stream_writer.send(exc)
                            continue
                        await read_stream_writer.send(message)
        except (anyio.ClosedResourceError, anyio.BrokenResourceError):
            # The receiving end (or stdout itself) was closed during
            # teardown; stop reading quietly.
            await anyio.lowlevel.checkpoint()

    async def stdin_writer():
        """Write JSON-RPC messages to process stdin"""
        assert process.stdin, "Opened process is missing stdin"
        try:
            async with write_stream_reader:
                async for message in write_stream_reader:
                    json_data = message.model_dump_json(by_alias=True, exclude_none=True)
                    await process.stdin.send(
                        (json_data + "\n").encode(
                            encoding=server.encoding,
                            errors=server.encoding_error_handler,
                        )
                    )
        except anyio.ClosedResourceError:
            await anyio.lowlevel.checkpoint()

    async def watch_process_exit():
        """Monitor process exit and raise error if non-zero exit code"""
        returncode = await process.wait()
        if returncode != 0:
            raise RuntimeError(f"Subprocess exited with code {returncode}. Command: {server.command} {server.args}")

    async with anyio.create_task_group() as tg, process:
        tg.start_soon(stdout_reader)
        tg.start_soon(stdin_writer)
        tg.start_soon(watch_process_exit)

        try:
            # Give the process a moment to start up
            await anyio.sleep(0.2)

            yield read_stream, write_stream
        finally:
            # Close the caller-facing ends on the way out. Closing
            # write_stream ends stdin_writer's loop so the task group can
            # finish even if the caller never closed the stream itself.
            # The synchronous close() is used so no await happens while a
            # cancellation may be in flight.
            read_stream.close()
            write_stream.close()